# testing/exclusions.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php


import contextlib
import operator
import re

from sqlalchemy import util as sqla_util
from sqlalchemy.util import decorator

from . import config
from . import fixture_functions
from .. import util
from ..util.compat import inspect_getargspec


def skip_if(predicate, reason=None):
    """Build a rule that skips the test when ``predicate`` is true.

    ``predicate`` is coerced via ``_as_predicate``; ``reason`` becomes its
    description if it does not already have one.
    """
    rule = compound()
    rule.skips.add(_as_predicate(predicate, reason))
    return rule


def fails_if(predicate, reason=None):
    """Build a rule under which the test is expected to fail.

    ``predicate`` is coerced via ``_as_predicate``; ``reason`` becomes its
    description if it does not already have one.
    """
    rule = compound()
    rule.fails.add(_as_predicate(predicate, reason))
    return rule


class compound(object):
    """A composable collection of skip/fail rules and tags for a test.

    * ``skips`` -- predicates; if any is true for the config the test is
      skipped.
    * ``fails`` -- predicates under which the test is expected to fail;
      an unexpected success raises ``AssertionError``.
    * ``tags`` -- names consulted by :meth:`include_test`.
    * ``combinations`` -- maps a predicate to a ``{argname: value}`` dict;
      the predicate only applies when the test's positional arguments
      agree with it (see :meth:`with_combination`).

    Instances are applied as decorators to test functions.
    """

    def __init__(self):
        self.fails = set()
        self.skips = set()
        self.tags = set()
        self.combinations = {}

    def __add__(self, other):
        return self.add(other)

    def with_combination(self, **kw):
        """Return a copy whose skip/fail predicates only apply when the
        test's positional arguments match the name/value pairs in ``kw``."""
        copy = compound()
        copy.fails.update(self.fails)
        copy.skips.update(self.skips)
        copy.tags.update(self.tags)
        copy.combinations.update((f, kw) for f in copy.fails)
        copy.combinations.update((s, kw) for s in copy.skips)
        return copy

    def add(self, *others):
        """Return a new compound merging this one with each of ``others``.

        Skips, fails, tags and combination restrictions are all carried
        over, so a predicate restricted by :meth:`with_combination` keeps
        its restriction after being combined with other rules.
        """
        copy = compound()
        copy._extend(self)
        for other in others:
            copy._extend(other)
        return copy

    def not_(self):
        """Return a compound whose skip/fail predicates are negated."""
        copy = compound()
        copy.fails.update(NotPredicate(fail) for fail in self.fails)
        copy.skips.update(NotPredicate(skip) for skip in self.skips)
        copy.tags.update(self.tags)
        # NOTE(review): combinations are keyed by predicate object; the
        # NotPredicate wrappers are new objects, so no restrictions carry
        # over here.
        return copy

    @property
    def enabled(self):
        """True if no skip/fail predicate matches the current config."""
        return self.enabled_for_config(config._current)

    def enabled_for_config(self, config):
        """Return True unless any skip or fail predicate is true for
        ``config``."""
        return not any(
            predicate(config) for predicate in self.skips.union(self.fails)
        )

    def matching_config_reasons(self, config):
        """Return descriptions of every skip/fail predicate true for
        ``config``."""
        return [
            predicate._as_string(config)
            for predicate in self.skips.union(self.fails)
            if predicate(config)
        ]

    def include_test(self, include_tags, exclude_tags):
        """Return True if no tag is excluded and, when ``include_tags`` is
        non-empty, at least one tag is included."""
        return bool(
            not self.tags.intersection(exclude_tags)
            and (not include_tags or self.tags.intersection(include_tags))
        )

    def _extend(self, other):
        """Merge ``other``'s rules into this compound in place."""
        self.skips.update(other.skips)
        self.fails.update(other.fails)
        self.tags.update(other.tags)
        self.combinations.update(other.combinations)

    def __call__(self, fn):
        """Decorate test function ``fn`` with these rules.

        If ``fn`` was already decorated by a compound, this compound's
        rules are merged into that one instead of wrapping ``fn`` again.
        """
        if hasattr(fn, "_sa_exclusion_extend"):
            fn._sa_exclusion_extend._extend(self)
            return fn

        @decorator
        def decorate(fn, *args, **kw):
            return self._do(config._current, fn, *args, **kw)

        decorated = decorate(fn)
        decorated._sa_exclusion_extend = self
        return decorated

    @contextlib.contextmanager
    def fail_if(self):
        """Context manager treating every skip and fail predicate as a fail
        predicate for the enclosed block."""
        all_fails = compound()
        all_fails.fails.update(self.skips.union(self.fails))

        try:
            yield
        except Exception as ex:
            all_fails._expect_failure(config._current, ex, None)
        else:
            all_fails._expect_success(config._current, None)

    def _check_combinations(self, combination, predicate):
        """Return False if ``predicate`` is restricted to argument values
        that conflict with the test's ``combination``.

        ``combination`` is a list of ``(argname, value)`` pairs, or None
        when the test received no arguments beyond the first; None is
        treated as "no conflicting arguments".
        """
        if predicate not in self.combinations:
            return True
        restriction = self.combinations[predicate]
        for k, v in combination or ():
            if k in restriction and restriction[k] != v:
                return False
        return True

    def _do(self, cfg, fn, *args, **kw):
        """Run ``fn`` under these rules for config ``cfg``.

        Positional arguments after the first (presumably ``self``) are
        paired with ``fn``'s argument names to form the combination used to
        filter restricted predicates.
        """
        if len(args) > 1:
            insp = inspect_getargspec(fn)
            combination = list(zip(insp.args[1:], args[1:]))
        else:
            combination = None

        for skip in self.skips:
            if self._check_combinations(combination, skip) and skip(cfg):
                msg = "'%s' : %s" % (
                    fixture_functions.get_current_test_name(),
                    skip._as_string(cfg),
                )
                config.skip_test(msg)

        try:
            return_value = fn(*args, **kw)
        except Exception as ex:
            self._expect_failure(cfg, ex, combination, name=fn.__name__)
        else:
            self._expect_success(cfg, combination, name=fn.__name__)
            return return_value

    def _expect_failure(self, config, ex, combination, name="block"):
        """Swallow ``ex`` if a matching fail predicate expected it;
        otherwise re-raise it."""
        for fail in self.fails:
            if self._check_combinations(combination, fail) and fail(config):
                if sqla_util.py2k:
                    str_ex = unicode(ex).encode(  # noqa: F821
                        "utf-8", errors="ignore"
                    )
                else:
                    str_ex = str(ex)
                print(
                    (
                        "%s failed as expected (%s): %s "
                        % (name, fail._as_string(config), str_ex)
                    )
                )
                break
        else:
            util.raise_from_cause(ex)

    def _expect_success(self, config, combination, name="block"):
        """Raise AssertionError if a matching fail predicate expected a
        failure that did not happen."""
        if not self.fails:
            return

        for fail in self.fails:
            if self._check_combinations(combination, fail) and fail(config):
                raise AssertionError(
                    "Unexpected success for '%s' (%s)"
                    % (
                        name,
                        " and ".join(
                            fail._as_string(config) for fail in self.fails
                        ),
                    )
                )


def requires_tag(tagname):
    """Return a rule carrying the single tag ``tagname``."""
    single = (tagname,)
    return tags(single)


def tags(tagnames):
    """Return a rule with no predicates, carrying every name in
    ``tagnames``."""
    rule = compound()
    rule.tags.update(tagnames)
    return rule


def only_if(predicate, reason=None):
    """Return a rule that skips the test unless ``predicate`` is true."""
    negated = NotPredicate(_as_predicate(predicate))
    return skip_if(negated, reason)


def succeeds_if(predicate, reason=None):
    """Return a rule expecting failure unless ``predicate`` is true."""
    negated = NotPredicate(_as_predicate(predicate))
    return fails_if(negated, reason)


class Predicate(object):
    """Base class for conditions evaluated against a test config."""

    @classmethod
    def as_predicate(cls, predicate, description=None):
        """Coerce ``predicate`` into a :class:`Predicate`.

        Accepted forms:

        * a ``compound`` -- its ``enabled_for_config`` becomes a lambda
          predicate;
        * a ``Predicate`` -- returned as-is, with ``description`` filled in
          if it had none;
        * a list or set -- an ``OrPredicate`` over each coerced member;
        * a tuple -- positional arguments to ``SpecPredicate``;
        * a string such as ``"postgresql"`` or ``"mysql >= 5.7"``;
        * any other callable -- wrapped in ``LambdaPredicate``.
        """
        if isinstance(predicate, compound):
            return cls.as_predicate(predicate.enabled_for_config, description)

        if isinstance(predicate, Predicate):
            if description and predicate.description is None:
                predicate.description = description
            return predicate

        if isinstance(predicate, (list, set)):
            members = [cls.as_predicate(member) for member in predicate]
            return OrPredicate(members, description)

        if isinstance(predicate, tuple):
            return SpecPredicate(*predicate)

        if isinstance(predicate, sqla_util.string_types):
            # "<db[+driver]> [<op> <dotted version>]"
            match = re.match(
                r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
            )
            if not match:
                raise ValueError(
                    "Couldn't locate DB name in predicate: %r" % predicate
                )
            db_name, op, version_text = match.groups()
            if version_text:
                spec = tuple(int(part) for part in version_text.split("."))
            else:
                spec = None
            return SpecPredicate(db_name, op, spec, description=description)

        if callable(predicate):
            return LambdaPredicate(predicate, description)

        assert False, "unknown predicate type: %s" % predicate

    def _format_description(self, config, negate=False):
        """Interpolate ``self.description`` with driver/database names and
        "does support"/"doesn't support" wording.
        """
        # Always evaluate the predicate; SpecPredicate may open a
        # connection when called.
        outcome = self(config)
        if negate:
            # NOTE(review): a negated request forces False rather than
            # inverting the outcome. Confirm intent before changing.
            outcome = False

        if config:
            driver_name = config.db.url.get_driver_name()
            backend_name = config.db.url.get_backend_name()
        else:
            driver_name = "<no driver>"
            backend_name = "<no database>"

        supported = "does support"
        unsupported = "doesn't support"
        return self.description % {
            "driver": driver_name,
            "database": backend_name,
            "doesnt_support": unsupported if outcome else supported,
            "does_support": supported if outcome else unsupported,
        }

    def _as_string(self, config=None, negate=False):
        """Return a human-readable description; subclasses implement."""
        raise NotImplementedError()


class BooleanPredicate(Predicate):
    """Predicate that always evaluates to the fixed ``value``."""

    def __init__(self, value, description=None):
        self.value = value
        if not description:
            description = "boolean %s" % value
        self.description = description

    def __call__(self, config):
        return self.value

    def _as_string(self, config, negate=False):
        return self._format_description(config, negate=negate)


class SpecPredicate(Predicate):
    """Predicate matching a backend name, optional ``+driver``, and
    optional server-version comparison.

    ``op`` is either a key of ``_ops`` or a callable taking
    ``(server_version, spec)``.
    """

    def __init__(self, db, op=None, spec=None, description=None):
        self.db = db
        self.op = op
        self.spec = spec
        self.description = description

    _ops = {
        "<": operator.lt,
        ">": operator.gt,
        "==": operator.eq,
        "!=": operator.ne,
        "<=": operator.le,
        ">=": operator.ge,
        "in": operator.contains,
        "between": lambda val, pair: val >= pair[0] and val <= pair[1],
    }

    def __call__(self, config):
        engine = config.db

        if "+" not in self.db:
            dialect, driver = self.db, None
        else:
            dialect, driver = self.db.split("+")

        if dialect and engine.name != dialect:
            return False
        if driver is not None and engine.driver != driver:
            return False
        if self.op is None:
            return True

        assert driver is None, "DBAPI version specs not supported yet"

        version = _server_version(engine)
        if hasattr(self.op, "__call__") and self.op:
            comparator = self.op
        else:
            comparator = self._ops[self.op]
        return comparator(version, self.spec)

    def _as_string(self, config, negate=False):
        if self.description is not None:
            return self._format_description(config)
        prefix = "not " if negate else ""
        if self.op is None:
            return "%s%s" % (prefix, self.db)
        return "%s%s %s %s" % (prefix, self.db, self.op, self.spec)


class LambdaPredicate(Predicate):
    """Predicate backed by an arbitrary callable.

    Callables with no positional parameters are adapted to the
    one-argument ``(config)`` calling convention.
    """

    def __init__(self, lambda_, description=None, args=None, kw=None):
        positional = inspect_getargspec(lambda_)[0]
        if positional:
            self.lambda_ = lambda_
        else:
            self.lambda_ = lambda db: lambda_()
        self.args = args or ()
        self.kw = kw or {}
        self.description = description or lambda_.__doc__ or "custom function"

    def __call__(self, config):
        return self.lambda_(config)

    def _as_string(self, config, negate=False):
        return self._format_description(config)


class NotPredicate(Predicate):
    """Logical negation of a wrapped predicate."""

    def __init__(self, predicate, description=None):
        self.predicate = predicate
        self.description = description

    def __call__(self, config):
        return not self.predicate(config)

    def _as_string(self, config, negate=False):
        flipped = not negate
        if not self.description:
            return self.predicate._as_string(config, flipped)
        return self._format_description(config, flipped)


class OrPredicate(Predicate):
    """True when any of the wrapped predicates is true."""

    def __init__(self, predicates, description=None):
        self.predicates = predicates
        self.description = description

    def __call__(self, config):
        return any(pred(config) for pred in self.predicates)

    def _eval_str(self, config, negate=False):
        # De Morgan: "not (a or b)" is rendered as "not a and not b".
        joiner = " and " if negate else " or "
        parts = [p._as_string(config, negate=negate) for p in self.predicates]
        return joiner.join(parts)

    def _negation_str(self, config):
        if self.description is None:
            return self._eval_str(config, negate=True)
        return "Not " + self._format_description(config)

    def _as_string(self, config, negate=False):
        if negate:
            return self._negation_str(config)
        if self.description is None:
            return self._eval_str(config)
        return self._format_description(config)


# Module-level shorthand for Predicate.as_predicate, used by the rule
# builders in this module (skip_if, fails_if, only_if, succeeds_if).
_as_predicate = Predicate.as_predicate


def _is_excluded(db, op, spec):
    """Evaluate a SpecPredicate for ``db``/``op``/``spec`` against the
    current config."""
    predicate = SpecPredicate(db, op, spec)
    return predicate(config._current)


def _server_version(engine):
    """Return a server_version_info tuple."""

    # force metadata to be retrieved
    conn = engine.connect()
    version = getattr(engine.dialect, "server_version_info", None)
    if version is None:
        version = ()
    conn.close()
    return version


def db_spec(*dbs):
    """Return an OrPredicate matching any of ``dbs``."""
    members = [Predicate.as_predicate(db) for db in dbs]
    return OrPredicate(members)


def open():  # noqa
    """Return a rule whose skip predicate is always False."""
    never = BooleanPredicate(False, "mark as execute")
    return skip_if(never)


def closed():
    """Return a rule that always skips the test."""
    always = BooleanPredicate(True, "marked as skip")
    return skip_if(always)


def fails(reason=None):
    """Return a rule under which the test is always expected to fail."""
    if not reason:
        reason = "expected to fail"
    return fails_if(BooleanPredicate(True, reason))


def future(fn):
    """Mark a test as expected to unconditionally fail.

    Takes no arguments; apply without parentheses (``@future``). The
    decorated test still runs. If it raises, the failure is reported as
    expected; if it succeeds, ``AssertionError`` is raised. (Previously the
    decorated test body was never executed and the call returned a rule
    object, so such tests silently passed.)
    """
    return fails("Future feature")(fn)


def fails_on(db, reason=None):
    """Return a rule expecting failure wherever ``db`` matches."""
    return fails_if(predicate=db, reason=reason)


def fails_on_everything_except(*dbs):
    """Return a rule expecting failure unless one of ``dbs`` matches."""
    allowed = OrPredicate([Predicate.as_predicate(db) for db in dbs])
    return succeeds_if(allowed)


def skip(db, reason=None):
    """Return a rule skipping the test wherever ``db`` matches."""
    return skip_if(predicate=db, reason=reason)


def only_on(dbs, reason=None):
    """Return a rule skipping the test unless one of ``dbs`` matches.

    ``reason`` is attached to each coerced member predicate that has no
    description of its own.
    """
    members = [
        Predicate.as_predicate(db, reason) for db in util.to_list(dbs)
    ]
    return only_if(OrPredicate(members))


def exclude(db, op, spec, reason=None):
    """Return a rule skipping the test when the db/version spec matches."""
    spec_predicate = SpecPredicate(db, op, spec)
    return skip_if(spec_predicate, reason=reason)


def against(config, *queries):
    """Return True if ``config`` matches any of ``queries``."""
    assert queries, "no queries sent!"
    combined = OrPredicate([Predicate.as_predicate(q) for q in queries])
    return combined(config)
